# Causal MNIST digit-image generation: trains label mechanisms and a label-conditioned image mechanism with WGAN-GP critics
import copy
import os

import torch
import torchvision.transforms as transforms
import pickle

from Causal_MNIST_Images.DigitImageGeneration.mnist_image_generation import plot_dataset_digits
from Causal_MNIST_Images.DigitImageGeneration.morphomnist import io
from Causal_MNIST_Images.EvaluateCausalGAN import evaluate_after_epochs

from Causal_MNIST_Images.mnistControllerModel import get_discriminators, get_generated_labels

from Causal_MNIST_Images.mnistControllerModel import get_generators
from Causal_MNIST_Images.GroundTruth.CausalGraph_Mnist import set_nonid_mnist_images_minimized
from ModularUtils.ControllerConstants import get_multiple_labels_fill, fill2d_to_fill4d
from ModularUtils.Experiment_Class import Experiment
from ModularUtils.FunctionsConstant import asKey
from ModularUtils.FunctionsTraining import labels_image_gradient_penalty, save_checkpoint


def get_dataset(Exp, label, dno):
    """Load the stored values of one label of dataset ``dno`` as a float column tensor.

    For each feature suffix (currently only ``"feature"``) the pickle file
    ``Exp.file_roots[dno] + label + suffix + ".pkl"`` is loaded, converted with
    ``torch.FloatTensor`` and reshaped to ``(num_samples, 1)``. The columns are
    concatenated along dim 1, moved to ``Exp.DEVICE``, and the resulting shape is printed.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point
    ``Exp.file_roots`` at trusted data.
    """
    def _load_column(suffix):
        path = Exp.file_roots[dno] + label + suffix + ".pkl"
        with open(path, 'rb') as handle:
            raw = pickle.load(handle)
        values = torch.FloatTensor(raw)
        return values.view(len(values), 1)

    columns = [_load_column(suffix) for suffix in ["feature"]]
    stacked = torch.cat(columns, 1).to(Exp.DEVICE)
    print(stacked.shape)
    return stacked


def Imagetrain_CausalController(Exp, cur_mechs, label_generators, G_optimizers, label_discriminator, D_optimizer,
                                dataset_dict_batches, imagedata_dict_batches, batchno):
    """Run one training step on batch ``batchno`` for the critics and the mechanisms in ``cur_mechs``.

    For every intervention key of ``dataset_dict_batches`` (enumerated in dict order; the index
    ``interv_no`` selects the matching entry of ``label_discriminator``, ``D_optimizer`` and
    ``Exp.train_mech_dict[mech]``):
      1. collect the labels the critic compares (``compare_Var``) from the non-image mechanisms;
      2. build the real label fill matching the real image batch;
      3. sample labels and the image via ``get_generated_labels``, passing the observed values of the
         intervened variables as intervention tensors;
      4. run ``Exp.CRITIC_ITERATIONS`` WGAN-GP critic updates;
      5. add the generator term ``-mean(D(fake))`` for this intervention.
    The summed generator loss is back-propagated once, and every optimizer in ``cur_mechs`` is stepped.

    Args:
        Exp: experiment configuration object.
        cur_mechs: names of the mechanisms being trained.
        label_generators / G_optimizers: generator modules and optimizers, keyed by mechanism name.
        label_discriminator / D_optimizer: critics and their optimizers, indexed by intervention number.
        dataset_dict_batches: intervention key -> list of label batches.
        imagedata_dict_batches: intervention key -> list of image batches.
        batchno: index of the batch to train on.

    Returns:
        tuple: ``(G_loss.data, D_loss.data)``. The first is the generator loss summed over
        interventions (shape ``(1,)``). The second is the mean of all critic losses recorded during
        this step.
    """
    G_loss = torch.zeros(1).to(Exp.DEVICE)
    # Collected across *all* interventions so the reported D loss covers every critic update of the
    # step. This also keeps the name bound when there are no interventions.
    D_losses = []

    for interv_no, (intv_key_tup, dataset_batches) in enumerate(dataset_dict_batches.items()):
        intv_key = dict(intv_key_tup)

        data_input = dataset_batches[batchno]
        intervened_Var = list(intv_key.keys())

        # Order-preserving union of the "compare" labels of the trained non-image mechanisms.
        compare_Var = []
        for mech in cur_mechs:
            if mech in Exp.image_labels:
                continue
            compare_Var += [lb for lb in Exp.train_mech_dict[mech][interv_no]["compare"] if lb not in compare_Var]

        mini_batch = data_input.size()[0]
        indices = [Exp.label_names.index(lb) for lb in compare_Var]
        current_real_label = data_input[:, indices].type(torch.LongTensor).view(-1, len(indices)).to(Exp.DEVICE)

        compare_dims = [Exp.label_dim[lb]["feature"] for lb in compare_Var]
        real_labels_fill = get_multiple_labels_fill(Exp, current_real_label, compare_dims, isImage_labels=True,
                                                    more_dimsize=Exp.IMAGE_SIZE)

        obs_images = imagedata_dict_batches[intv_key_tup][batchno]

        # One intervention tensor per intervened variable (the loop is empty for observational data).
        intv_tensor_dict = {}
        for intv_lb in intervened_Var:
            index = [Exp.label_names.index(intv_lb)]
            parent_intv_label = data_input[:, index].type(torch.LongTensor).view(-1, 1).to(Exp.DEVICE)
            intv_dims = [Exp.label_dim[intv_lb]["feature"]]
            intv_tensor_dict[intv_lb] = get_multiple_labels_fill(Exp, parent_intv_label, intv_dims,
                                                                 isImage_labels=False)

        # Generate the compared labels plus the (first) image label; the image is split off so the
        # remaining generated labels can be turned into the critic's conditioning fill.
        generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_tensor_dict,
                                                     compare_Var + [Exp.image_labels[0]], mini_batch, hard=True)
        generated_image = generated_labels_dict.pop(Exp.image_labels[0])
        y_dims = sum(compare_dims)
        ret2d = torch.cat(list(generated_labels_dict.values()), 1).view(-1, y_dims)
        generated_labels_fill = fill2d_to_fill4d(Exp, ret2d, more_dimsize=Exp.IMAGE_SIZE)

        critic = label_discriminator[interv_no]
        for _ in range(Exp.CRITIC_ITERATIONS):
            D_real_decision_obs = critic(obs_images, real_labels_fill).squeeze()
            D_fake_decision_obs = critic(generated_image, generated_labels_fill).squeeze()

            gp_obs = labels_image_gradient_penalty(critic, obs_images, real_labels_fill, generated_image,
                                                   generated_labels_fill,
                                                   device=Exp.DEVICE)

            # WGAN-GP critic loss: -(E[D(real)] - E[D(fake)]) + lambda * GP.
            D_loss_obs = -(torch.mean(D_real_decision_obs) - torch.mean(D_fake_decision_obs)) \
                + Exp.LAMBDA_GP * gp_obs

            D_losses.append(D_loss_obs.detach())
            critic.zero_grad()
            # The same fake samples are reused by later critic iterations and by the generator loss
            # below, so the generator part of the graph must survive this backward pass.
            D_loss_obs.backward(retain_graph=True)
            D_optimizer[interv_no].step()

        # Accumulate the generator losses over all interventions.
        D_fake_decision_obs = critic(generated_image, generated_labels_fill).squeeze()
        G_loss += -torch.mean(D_fake_decision_obs)

    # Back propagation for the generators (gradients left over from the critic passes are cleared first).
    for mech in cur_mechs:
        label_generators[mech].zero_grad()

    G_loss.backward()

    for mech in cur_mechs:
        G_optimizers[mech].step()

    D_loss = torch.mean(torch.FloatTensor(D_losses))  # just the mean of the recorded critic losses

    return G_loss.data, D_loss.data



def _to_batches(dataset, batch_size):
    """Split ``dataset`` into sequential (unshuffled) mini-batches with singleton dimensions removed.

    Matches ``torch.squeeze`` on every batch, except that the leading batch dimension is never
    dropped. A trailing batch of size 1 therefore keeps the same rank as the others.
    """
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
    batches = []
    for batch in loader:
        kept_dims = [size for size in batch.shape[1:] if size != 1]
        batches.append(batch.reshape(batch.shape[0], *kept_dims))
    return batches


def imageMain(Exp, cur_mechs, label_generators, G_optimizers, discriminators, D_optimizers, dataset_dict,
              imagedata_dict, tvd_diff, kl_diff):
    """Train ``cur_mechs`` for one epoch (``Exp.curr_epoochs``) over all batches.

    Both the label datasets and the image datasets are split into unshuffled batches of
    ``Exp.batch_size``. ``Imagetrain_CausalController`` runs once per batch. Losses are appended to
    ``Exp.D_avg_losses`` / ``Exp.G_avg_losses``, the temperature is annealed every 100 global
    iterations, and ``evaluate_after_epochs`` runs ``Exp.PLOTS_PER_EPOCH`` times per epoch (at least
    every batch when there are fewer batches than that). A checkpoint is saved every 5th epoch.

    Returns:
        The constant 100, which the caller stores as ``mech_tvd``.

    NOTE(review): ``tvd_diff`` / ``kl_diff`` are rebound locally from ``evaluate_after_epochs`` but
    never returned. Updates only reach the caller if that function mutates the dicts in place. Confirm.
    """
    dataset_dict_batches = {key: _to_batches(ds, Exp.batch_size) for key, ds in dataset_dict.items()}
    imagedata_dict_batches = {key: _to_batches(ds, Exp.batch_size) for key, ds in imagedata_dict.items()}

    # Every step indexes both the label and the image batches, so only batch numbers present in all
    # of them are usable.
    all_batch_lists = list(dataset_dict_batches.values()) + list(imagedata_dict_batches.values())
    num_batches = min((len(batches) for batches in all_batch_lists), default=0)

    # Guard against a zero modulus when there are fewer batches than plots per epoch.
    plot_every = max(1, int(num_batches / Exp.PLOTS_PER_EPOCH))

    for batchno in range(num_batches):
        G_loss, D_loss = Imagetrain_CausalController(Exp, cur_mechs, label_generators, G_optimizers, discriminators,
                                                     D_optimizers, dataset_dict_batches, imagedata_dict_batches,
                                                     batchno)

        print('Epoch [%d/%d], Step [%d/%d],' % (
            Exp.curr_epoochs + 1, Exp.num_epochs, batchno + 1, num_batches),
              'mechanism: ', cur_mechs, 'D_loss: %.4f, G_loss: %.4f' % (D_loss.data, G_loss.data))

        # Annealing, driven by the global iteration count.
        tot_iter = Exp.curr_epoochs * num_batches + batchno
        if tot_iter % 100 == 0:
            Exp.anneal_temperature(tot_iter)

        if (batchno + 1) % plot_every == 0:
            tvd_diff, kl_diff = evaluate_after_epochs(Exp, cur_mechs, label_generators, dataset_dict, tvd_diff, kl_diff)

        Exp.D_avg_losses.append(torch.mean(D_loss))
        Exp.G_avg_losses.append(torch.mean(G_loss))

    if (Exp.curr_epoochs + 1) % 5 == 0:
        var_list = "".join(cur_mechs)
        save_checkpoint(Exp, Exp.SAVED_PATH, cur_mechs, label_generators, G_optimizers, {var_list: discriminators},
                        {var_list: D_optimizers})  # change this
        print("saved at ", Exp.SAVED_PATH)

    return 100


if __name__ == "__main__":

    Exp = Experiment("Exp1", set_nonid_mnist_images_minimized,
                     dist_thresh=0.15,
                     causal_hierarchy=2,
                     Temperature=1,
                     temp_min=0.01,
                     NOISE_DIM=128,
                     CONF_NOISE_DIM=128,
                     G_hid_dims=[256, 256],
                     D_hid_dims=[256, 256, 256],
                     IMAGE_FILTERS=[128, 64, 32],
                     CRITIC_ITERATIONS=1,
                     LAMBDA_GP=10,
                     learning_rate=2 * 1e-4,
                     Synthetic_Sample_Size=20000,
                     intv_Sample_Size=20000,
                     batch_size=200,
                     features=["feature"],
                     noise_states=100,
                     latent_state=16,
                     # Data_intervs=[{}, {"X1": 0}, {"X1": 1}],
                     Data_intervs=[{}],
                     num_epochs=300,
                     new_experiment=True
                     )

    print(Exp.Data_intervs)
    # Interventional batches use the same size as observational ones.
    Exp.intv_batch_size = Exp.batch_size

    os.makedirs(Exp.SAVED_PATH, exist_ok=True)

    # The commented lines below show a resume configuration (saved run path + per-mechanism load flags).
    # Exp.LOAD_MODEL_PATH = "/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/Exp1/Sep_23_2022-10_45"
    # Exp.load_which_models = {"X1": True, "X2": True, "W": True, "ImgYdigit1": True, "ImgYdigit2":True}
    Exp.load_which_models = {"X1": False, "X2": False, "W": False, "ImgYdigit1": False, "ImgYdigit2":False}
    # The last mechanism is the image mechanism; all earlier ones are passed as discrete (label) mechanisms.
    cur_mechs = ["X1", "X2", "W", "ImgYdigit1"]
    discrete_mechs = cur_mechs[0:-1]
    # Alternative: cur_mechs = ["X1", "X2", "W", "ImgYdigit2"]

    label_generators, optimizersMech = get_generators(Exp, Exp.load_which_models)

    discriminatorsMech, doptimizersMech = get_discriminators(Exp, cur_mechs, discrete_mechs,
                                                             Exp.load_which_models)  # update for interventional training

    # Label data: one (num_samples, num_non_image_labels) tensor per dataset, keyed by its intervention.
    dataset_dict = {}
    for dno in range(Exp.num_datasets):
        label_columns = [get_dataset(Exp, label, dno) for label in Exp.label_names
                         if label not in Exp.image_labels]
        dataset_dict[asKey(Exp.Data_intervs[dno])] = torch.cat(label_columns, 1).to(Exp.DEVICE)

    # Image data. The transform does not depend on the dataset, so build it once.
    # NOTE(review): there is no Normalize step, so ToTensor output (scaled to [0, 1] for 8-bit input)
    # is used directly. Confirm this matches the image generator's output activation.
    transform = transforms.Compose([transforms.ToPILImage(),
                                    transforms.ToTensor()
                                    ])
    imagedata_dict = {}
    for dno in range(Exp.num_datasets):
        loaded_images = io.load_idx(Exp.file_roots[dno] + "Ydigit1images.gz")
        # Stack on the host and transfer once, instead of one device copy per image.
        images = torch.cat([torch.unsqueeze(transform(img), dim=0) for img in loaded_images], 0)
        imagedata_dict[asKey(Exp.Data_intervs[dno])] = images.to(Exp.DEVICE)

    mech_tvd = 0
    tvd_diff = {}
    kl_diff = {}

    for epoch in range(Exp.num_epochs):
        Exp.curr_epoochs = epoch
        mech_tvd = imageMain(Exp, cur_mechs, label_generators, optimizersMech, discriminatorsMech, doptimizersMech,
                             dataset_dict, imagedata_dict, tvd_diff, kl_diff)
